{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "# Save & Restore a Model\n",
    "\n",
    "Save and Restore a model using TensorFlow.\n",
    "This example is using the MNIST database of handwritten digits\n",
    "(http://yann.lecun.com/exdb/mnist/).\n",
    "\n",
    "- Author: Aymeric Damien\n",
    "- Project: https://github.com/aymericdamien/TensorFlow-Examples/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting MNIST_data/train-images-idx3-ubyte.gz\n",
      "Extracting MNIST_data/train-labels-idx1-ubyte.gz\n",
      "Extracting MNIST_data/t10k-images-idx3-ubyte.gz\n",
      "Extracting MNIST_data/t10k-labels-idx1-ubyte.gz\n"
     ]
    }
   ],
   "source": [
    "from __future__ import print_function\n",
    "\n",
    "# Import MINST data\n",
    "from tensorflow.examples.tutorials.mnist import input_data\n",
    "mnist = input_data.read_data_sets(\"MNIST_data/\", one_hot=True)\n",
    "\n",
    "import tensorflow as tf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parameters\n",
    "learning_rate = 0.001\n",
    "batch_size = 100\n",
    "display_step = 1\n",
    "model_path = \"/tmp/model.ckpt\"\n",
    "\n",
    "# Network Parameters\n",
    "n_hidden_1 = 256 # 1st layer number of features\n",
    "n_hidden_2 = 256 # 2nd layer number of features\n",
    "n_input = 784 # MNIST data input (img shape: 28*28)\n",
    "n_classes = 10 # MNIST total classes (0-9 digits)\n",
    "\n",
    "# tf Graph input\n",
    "x = tf.placeholder(\"float\", [None, n_input])\n",
    "y = tf.placeholder(\"float\", [None, n_classes])\n",
    "\n",
    "\n",
    "# Create model\n",
    "def multilayer_perceptron(x, weights, biases):\n",
    "    # Hidden layer with RELU activation\n",
    "    layer_1 = tf.add(tf.matmul(x, weights['h1']), biases['b1'])\n",
    "    layer_1 = tf.nn.relu(layer_1)\n",
    "    # Hidden layer with RELU activation\n",
    "    layer_2 = tf.add(tf.matmul(layer_1, weights['h2']), biases['b2'])\n",
    "    layer_2 = tf.nn.relu(layer_2)\n",
    "    # Output layer with linear activation\n",
    "    out_layer = tf.matmul(layer_2, weights['out']) + biases['out']\n",
    "    return out_layer\n",
    "\n",
    "# Store layers weight & bias\n",
    "weights = {\n",
    "    'h1': tf.Variable(tf.random_normal([n_input, n_hidden_1])),\n",
    "    'h2': tf.Variable(tf.random_normal([n_hidden_1, n_hidden_2])),\n",
    "    'out': tf.Variable(tf.random_normal([n_hidden_2, n_classes]))\n",
    "}\n",
    "biases = {\n",
    "    'b1': tf.Variable(tf.random_normal([n_hidden_1])),\n",
    "    'b2': tf.Variable(tf.random_normal([n_hidden_2])),\n",
    "    'out': tf.Variable(tf.random_normal([n_classes]))\n",
    "}\n",
    "\n",
    "# Construct model\n",
    "pred = multilayer_perceptron(x, weights, biases)\n",
    "\n",
    "# Define loss and optimizer\n",
    "cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))\n",
    "optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)\n",
    "\n",
    "# Initializing the variables\n",
    "init = tf.global_variables_initializer()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# 'Saver' op to save and restore all the variables\n",
    "saver = tf.train.Saver()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting 1st session...\n",
      "Epoch: 0001 cost= 187.778896380\n",
      "Epoch: 0002 cost= 42.367902536\n",
      "Epoch: 0003 cost= 26.488964058\n",
      "First Optimization Finished!\n",
      "Accuracy: 0.9075\n",
      "Model saved in file: /tmp/model.ckpt\n"
     ]
    }
   ],
   "source": [
    "# Running first session\n",
    "print(\"Starting 1st session...\")\n",
    "with tf.Session() as sess:\n",
    "    # Initialize variables\n",
    "    sess.run(init)\n",
    "\n",
    "    # Training cycle\n",
    "    for epoch in range(3):\n",
    "        avg_cost = 0.\n",
    "        total_batch = int(mnist.train.num_examples/batch_size)\n",
    "        # Loop over all batches\n",
    "        for i in range(total_batch):\n",
    "            batch_x, batch_y = mnist.train.next_batch(batch_size)\n",
    "            # Run optimization op (backprop) and cost op (to get loss value)\n",
    "            _, c = sess.run([optimizer, cost], feed_dict={x: batch_x,\n",
    "                                                          y: batch_y})\n",
    "            # Compute average loss\n",
    "            avg_cost += c / total_batch\n",
    "        # Display logs per epoch step\n",
    "        if epoch % display_step == 0:\n",
    "            print \"Epoch:\", '%04d' % (epoch+1), \"cost=\", \\\n",
    "                \"{:.9f}\".format(avg_cost)\n",
    "    print(\"First Optimization Finished!\")\n",
    "\n",
    "    # Test model\n",
    "    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))\n",
    "    # Calculate accuracy\n",
    "    accuracy = tf.reduce_mean(tf.cast(correct_prediction, \"float\"))\n",
    "    print(\"Accuracy:\", accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))\n",
    "\n",
    "    # Save model weights to disk\n",
    "    save_path = saver.save(sess, model_path)\n",
    "    print(\"Model saved in file: %s\" % save_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting 2nd session...\n",
      "Model restored from file: /tmp/model.ckpt\n",
      "Epoch: 0001 cost= 18.292712951\n",
      "Epoch: 0002 cost= 13.404136196\n",
      "Epoch: 0003 cost= 9.855191723\n",
      "Epoch: 0004 cost= 7.276933088\n",
      "Epoch: 0005 cost= 5.564581285\n",
      "Epoch: 0006 cost= 4.165259939\n",
      "Epoch: 0007 cost= 3.139393926\n",
      "Second Optimization Finished!\n",
      "Accuracy: 0.9385\n"
     ]
    }
   ],
   "source": [
    "# Running a new session\n",
    "print(\"Starting 2nd session...\")\n",
    "with tf.Session() as sess:\n",
    "    # Initialize variables\n",
    "    sess.run(init)\n",
    "\n",
    "    # Restore model weights from previously saved model\n",
    "    load_path = saver.restore(sess, model_path)\n",
    "    print(\"Model restored from file: %s\" % save_path)\n",
    "\n",
    "    # Resume training\n",
    "    for epoch in range(7):\n",
    "        avg_cost = 0.\n",
    "        total_batch = int(mnist.train.num_examples / batch_size)\n",
    "        # Loop over all batches\n",
    "        for i in range(total_batch):\n",
    "            batch_x, batch_y = mnist.train.next_batch(batch_size)\n",
    "            # Run optimization op (backprop) and cost op (to get loss value)\n",
    "            _, c = sess.run([optimizer, cost], feed_dict={x: batch_x,\n",
    "                                                          y: batch_y})\n",
    "            # Compute average loss\n",
    "            avg_cost += c / total_batch\n",
    "        # Display logs per epoch step\n",
    "        if epoch % display_step == 0:\n",
    "            print(\"Epoch:\", '%04d' % (epoch + 1), \"cost=\", \\\n",
    "                \"{:.9f}\".format(avg_cost))\n",
    "    print(\"Second Optimization Finished!\")\n",
    "\n",
    "    # Test model\n",
    "    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))\n",
    "    # Calculate accuracy\n",
    "    accuracy = tf.reduce_mean(tf.cast(correct_prediction, \"float\"))\n",
    "    print(\"Accuracy:\", accuracy.eval(\n",
    "        {x: mnist.test.images, y: mnist.test.labels}))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [default]",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
